{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本生成实战"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the TensorFlow version, the interpreter version and each library's version.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for lib in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(lib.__name__, lib.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "def solve_cudnn_error():\n",
    "    \"\"\"Turn on memory growth for every physical GPU TensorFlow reports.\n",
    "\n",
    "    Does nothing when no GPU is listed. On success prints the number of\n",
    "    physical and logical GPUs; if TensorFlow raises RuntimeError the error is\n",
    "    printed rather than propagated.\n",
    "    \"\"\"\n",
    "    physical_gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "    if not physical_gpus:\n",
    "        return\n",
    "\n",
    "    try:\n",
    "        # Memory growth currently has to be configured identically on all GPUs.\n",
    "        for device in physical_gpus:\n",
    "            tf.config.experimental.set_memory_growth(device, True)\n",
    "        logical_gpus = tf.config.experimental.list_logical_devices('GPU')\n",
    "        print(len(physical_gpus), \"Physical GPUs,\", len(logical_gpus), \"Logical GPUs\")\n",
    "    except RuntimeError as err:\n",
    "        # Memory growth must be set before the GPUs have been initialized.\n",
    "        print(err)\n",
    "\n",
    "solve_cudnn_error()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "source": [
    "# Read the whole training corpus into memory as a single string.\n",
    "input_filepath = './data/shakespeare.txt'\n",
    "# The context manager closes the file handle as soon as the read finishes, and the\n",
    "# explicit encoding keeps the result independent of the platform's default codec.\n",
    "with open(input_filepath, 'r', encoding='utf-8') as corpus_file:\n",
    "    text = corpus_file.read()\n",
    "\n",
    "# Corpus length in characters, followed by a short preview.\n",
    "print(len(text))\n",
    "print(text[0:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理\n",
    "1. 生成词表 generate vocab\n",
    "2. 建立字符到id的映射 build mapping char->id\n",
    "3. 将数据转化为id data->id_data\n",
    "4. 预测下一个字符的模型，abcd -> bcd\\<eos>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "source": [
    "# 1. Build the vocabulary: every distinct character of the corpus, in sorted order.\n",
    "vocab = sorted({ch for ch in text})\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n"
     ]
    }
   ],
   "source": [
    "# 2. Map each character to its integer id, i.e. its position in the sorted vocabulary.\n",
    "char2idx = dict(zip(vocab, range(len(vocab))))\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "source": [
    "# Reverse lookup: an array indexed by id that yields the corresponding character.\n",
    "idx2char = np.asarray(vocab)\n",
    "print(idx2char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "source": [
    "# 3. Encode the whole corpus as an array of character ids.\n",
    "text_as_int = np.array(list(map(char2idx.__getitem__, text)))\n",
    "print(text_as_int[0:10])\n",
    "print(text[0:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(18, shape=(), dtype=int32) F\n",
      "tf.Tensor(47, shape=(), dtype=int32) i\n",
      "tf.Tensor(\n",
      "[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43\n",
      "  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43\n",
      " 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49\n",
      "  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10\n",
      "  0 37 53 59  1], shape=(101,), dtype=int32)\n",
      "'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n",
      "tf.Tensor(\n",
      "[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1\n",
      " 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0\n",
      " 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8\n",
      "  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1\n",
      " 63 53 59  1 49], shape=(101,), dtype=int32)\n",
      "'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n"
     ]
    }
   ],
   "source": [
    "# 4.定义输入输出\n",
    "def split_input_target(id_text):\n",
    "    \"\"\"abcde -> abcd, bcde\"\"\"\n",
    "    return id_text[0:-1], id_text[1:]\n",
    "\n",
    "# Build the datasets: a per-character stream of ids, then fixed-length chunks of\n",
    "# seq_length + 1 ids (the extra id allows the one-step shift between input and target).\n",
    "seq_length = 100\n",
    "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
    "seq_dataset = char_dataset.batch(seq_length + 1, drop_remainder=True)\n",
    "\n",
    "# Sanity check: the first two character ids, then the first two chunks decoded back to text.\n",
    "for char_id in char_dataset.take(2):\n",
    "    print(char_id, idx2char[char_id.numpy()])\n",
    "\n",
    "for chunk in seq_dataset.take(2):\n",
    "    print(chunk)\n",
    "    print(repr(''.join(idx2char[chunk.numpy()])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43\n",
      "  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43\n",
      " 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49\n",
      "  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10\n",
      "  0 37 53 59]\n",
      "[47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43  1\n",
      " 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43 39\n",
      " 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49  6\n",
      "  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0\n",
      " 37 53 59  1]\n",
      "[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1\n",
      " 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0\n",
      " 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8\n",
      "  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1\n",
      " 63 53 59  1]\n",
      "[56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1 58\n",
      " 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0 13\n",
      " 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8  0\n",
      "  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1 63\n",
      " 53 59  1 49]\n"
     ]
    }
   ],
   "source": [
    "# Turn every chunk of ids into an (input, target) pair.\n",
    "# NOTE: this cell rebinds seq_dataset to its own transformed value, so it is meant to be\n",
    "# run once per kernel session; re-running it applies the map to already-split pairs.\n",
    "seq_dataset = seq_dataset.map(split_input_target)\n",
    "\n",
    "for sample_input, sample_target in seq_dataset.take(2):\n",
    "    print(sample_input.numpy())\n",
    "    print(sample_target.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle the (input, target) pairs, then group them into fixed-size batches.\n",
    "# NOTE: this cell rebinds seq_dataset to its own transformed value, so re-running it\n",
    "# shuffles and batches the already-batched data again.\n",
    "buffer_size = 10000\n",
    "batch_size = 64\n",
    "\n",
    "seq_dataset = seq_dataset.shuffle(buffer_size)\n",
    "seq_dataset = seq_dataset.batch(batch_size, drop_remainder=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (64, None, 256)           16640     \n",
      "_________________________________________________________________\n",
      "simple_rnn (SimpleRNN)       (64, None, 1024)          1311744   \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (64, None, 65)            66625     \n",
      "=================================================================\n",
      "Total params: 1,395,009\n",
      "Trainable params: 1,395,009\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "vocab_size = len(vocab) # vocabulary size\n",
    "embedding_dim = 256\n",
    "rnn_units = 1024\n",
    "\n",
    "def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
    "    \"\"\"Build the character-level language model: Embedding -> SimpleRNN -> Dense.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of distinct character ids; also the width of the output logits.\n",
    "        embedding_dim: size of each character embedding vector.\n",
    "        rnn_units: number of units in the recurrent layer.\n",
    "        batch_size: fixed batch size baked into the model's input shape.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras Sequential model that outputs logits of shape\n",
    "        [batch_size, sequence_length, vocab_size].\n",
    "    \"\"\"\n",
    "    model = keras.models.Sequential([\n",
    "        keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape = [batch_size, None]),\n",
    "        # return_sequences=True: emit an output at every step so the loss can be computed\n",
    "        # against the whole target sequence.\n",
    "        # stateful=True: keep the hidden state between calls. Text generation feeds the\n",
    "        # prompt once and then only the newest character, so without a carried state every\n",
    "        # prediction after the first would be conditioned on a single character.\n",
    "        keras.layers.SimpleRNN(units=rnn_units, return_sequences=True, stateful=True),\n",
    "        keras.layers.Dense(vocab_size)\n",
    "    ])\n",
    "    return model\n",
    "\n",
    "# Instantiate the training model with the dataset's batch size and show its layers.\n",
    "model = build_model(vocab_size, embedding_dim, rnn_units, batch_size)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 采样测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 100, 65)\n"
     ]
    }
   ],
   "source": [
    "# Even before training the model can run a forward pass; inspect its output shape on one batch.\n",
    "for input_example_batch, target_example_batch in seq_dataset.take(1):\n",
    "    example_batch_predictions = model(input_example_batch)\n",
    "    print(example_batch_predictions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[40]\n",
      " [61]\n",
      " [ 0]\n",
      " [47]\n",
      " [29]\n",
      " [33]\n",
      " [32]\n",
      " [22]\n",
      " [48]\n",
      " [42]\n",
      " [ 1]\n",
      " [54]\n",
      " [10]\n",
      " [ 6]\n",
      " [25]\n",
      " [38]\n",
      " [43]\n",
      " [23]\n",
      " [43]\n",
      " [ 5]\n",
      " [59]\n",
      " [ 3]\n",
      " [ 0]\n",
      " [ 4]\n",
      " [ 0]\n",
      " [45]\n",
      " [55]\n",
      " [50]\n",
      " [37]\n",
      " [61]\n",
      " [13]\n",
      " [ 7]\n",
      " [23]\n",
      " [ 2]\n",
      " [21]\n",
      " [55]\n",
      " [20]\n",
      " [16]\n",
      " [42]\n",
      " [15]\n",
      " [64]\n",
      " [ 8]\n",
      " [60]\n",
      " [ 3]\n",
      " [56]\n",
      " [42]\n",
      " [36]\n",
      " [11]\n",
      " [29]\n",
      " [55]\n",
      " [43]\n",
      " [29]\n",
      " [ 5]\n",
      " [64]\n",
      " [ 7]\n",
      " [30]\n",
      " [56]\n",
      " [16]\n",
      " [18]\n",
      " [15]\n",
      " [64]\n",
      " [55]\n",
      " [37]\n",
      " [57]\n",
      " [12]\n",
      " [29]\n",
      " [15]\n",
      " [28]\n",
      " [62]\n",
      " [ 9]\n",
      " [46]\n",
      " [24]\n",
      " [64]\n",
      " [54]\n",
      " [33]\n",
      " [47]\n",
      " [18]\n",
      " [ 6]\n",
      " [37]\n",
      " [21]\n",
      " [12]\n",
      " [47]\n",
      " [29]\n",
      " [12]\n",
      " [44]\n",
      " [63]\n",
      " [51]\n",
      " [ 6]\n",
      " [ 1]\n",
      " [21]\n",
      " [41]\n",
      " [28]\n",
      " [49]\n",
      " [23]\n",
      " [33]\n",
      " [15]\n",
      " [57]\n",
      " [59]\n",
      " [11]\n",
      " [50]], shape=(100, 1), dtype=int64)\n",
      "tf.Tensor(\n",
      "[40 61  0 47 29 33 32 22 48 42  1 54 10  6 25 38 43 23 43  5 59  3  0  4\n",
      "  0 45 55 50 37 61 13  7 23  2 21 55 20 16 42 15 64  8 60  3 56 42 36 11\n",
      " 29 55 43 29  5 64  7 30 56 16 18 15 64 55 37 57 12 29 15 28 62  9 46 24\n",
      " 64 54 33 47 18  6 37 21 12 47 29 12 44 63 51  6  1 21 41 28 49 23 33 15\n",
      " 57 59 11 50], shape=(100,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "# Draw one character id per time step from the first example's logits\n",
    "# (random sampling; greedy decoding would be the alternative).\n",
    "sample_indices = tf.random.categorical(example_batch_predictions[0], 1)\n",
    "print(sample_indices) # shape after sampling: (100, 65) -> (100, 1)\n",
    "\n",
    "# Drop the trailing size-1 axis to get a flat vector of ids.\n",
    "sample_indices = tf.squeeze(sample_indices, axis=-1)\n",
    "print(sample_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: \"ment.\\nYou thus employ'd, I will go root away\\nThe noisome weeds, which without profit suck\\nThe soil's\"\n",
      "\n",
      "Output: \"ent.\\nYou thus employ'd, I will go root away\\nThe noisome weeds, which without profit suck\\nThe soil's \"\n",
      "\n",
      "Predictions: \"bw\\niQUTJjd p:,MZeKe'u$\\n&\\ngqlYwA-K!IqHDdCz.v$rdX;QqeQ'z-RrDFCzqYs?QCPx3hLzpUiF,YI?iQ?fym, IcPkKUCsu;l\"\n"
     ]
    }
   ],
   "source": [
    "# Show the input, the expected output and the sampled prediction, decoded back to text.\n",
    "for label, ids in (('Input:', input_example_batch[0]), ('Output:', target_example_batch[0])):\n",
    "    print(label, repr(''.join(idx2char[ids])))\n",
    "    print()\n",
    "print('Predictions:', repr(''.join(idx2char[sample_indices])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 100)\n",
      "4.1961346\n"
     ]
    }
   ],
   "source": [
    "# Custom loss: the model outputs raw logits, so the cross-entropy is computed with from_logits=True.\n",
    "def loss(labels, logits):\n",
    "    \"\"\"Return the sparse categorical cross-entropy between integer labels and logits.\"\"\"\n",
    "    return keras.losses.sparse_categorical_crossentropy(\n",
    "        y_true=labels, y_pred=logits, from_logits=True)\n",
    "\n",
    "model.compile(loss=loss, optimizer='adam')\n",
    "\n",
    "# Sanity check: per-position loss shape and mean loss on the sample batch.\n",
    "example_loss = loss(target_example_batch, example_batch_predictions)\n",
    "print(example_loss.shape)\n",
    "print(np.mean(example_loss.numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "172/172 [==============================] - 39s 225ms/step - loss: 2.7038\n",
      "Epoch 2/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 2.0595\n",
      "Epoch 3/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.8420\n",
      "Epoch 4/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.7025\n",
      "Epoch 5/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.6059\n",
      "Epoch 6/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.5397\n",
      "Epoch 7/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.4898\n",
      "Epoch 8/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.4510\n",
      "Epoch 9/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.4209\n",
      "Epoch 10/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.3952\n",
      "Epoch 11/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.3709\n",
      "Epoch 12/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.3505\n",
      "Epoch 13/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.3308\n",
      "Epoch 14/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.3145\n",
      "Epoch 15/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.2995\n",
      "Epoch 16/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.2841\n",
      "Epoch 17/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.2699\n",
      "Epoch 18/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.2566\n",
      "Epoch 19/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.2412\n",
      "Epoch 20/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.2303\n",
      "Epoch 21/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.2176\n",
      "Epoch 22/100\n",
      "172/172 [==============================] - 37s 215ms/step - loss: 1.2039\n",
      "Epoch 23/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.1918\n",
      "Epoch 24/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.1785\n",
      "Epoch 25/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.1691\n",
      "Epoch 26/100\n",
      "172/172 [==============================] - 37s 215ms/step - loss: 1.1561\n",
      "Epoch 27/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.1465\n",
      "Epoch 28/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 1.1324\n",
      "Epoch 29/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.1223\n",
      "Epoch 30/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.1101\n",
      "Epoch 31/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.1019\n",
      "Epoch 32/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.0889\n",
      "Epoch 33/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.0786\n",
      "Epoch 34/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.0705\n",
      "Epoch 35/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0601\n",
      "Epoch 36/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.0507\n",
      "Epoch 37/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.0423\n",
      "Epoch 38/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.0361\n",
      "Epoch 39/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 1.0252\n",
      "Epoch 40/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0207\n",
      "Epoch 41/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0104\n",
      "Epoch 42/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.00482s - l\n",
      "Epoch 43/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9966\n",
      "Epoch 44/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9957\n",
      "Epoch 45/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9874\n",
      "Epoch 46/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9820\n",
      "Epoch 47/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 0.9773\n",
      "Epoch 48/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9743\n",
      "Epoch 49/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9711\n",
      "Epoch 50/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9687\n",
      "Epoch 51/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9615\n",
      "Epoch 52/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9616\n",
      "Epoch 53/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9603\n",
      "Epoch 54/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9595\n",
      "Epoch 55/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 0.9544\n",
      "Epoch 56/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9520\n",
      "Epoch 57/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 0.9541\n",
      "Epoch 58/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9501\n",
      "Epoch 59/100\n",
      "172/172 [==============================] - 38s 218ms/step - loss: 0.9473\n",
      "Epoch 60/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9470\n",
      "Epoch 61/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9473\n",
      "Epoch 62/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.94950s - loss: 0.\n",
      "Epoch 63/100\n",
      "172/172 [==============================] - 38s 222ms/step - loss: 0.9492\n",
      "Epoch 64/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9508\n",
      "Epoch 65/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9494\n",
      "Epoch 66/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9527\n",
      "Epoch 67/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9535\n",
      "Epoch 68/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9512\n",
      "Epoch 69/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9501\n",
      "Epoch 70/100\n",
      "172/172 [==============================] - 38s 221ms/step - loss: 0.9569\n",
      "Epoch 71/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 0.9606\n",
      "Epoch 72/100\n",
      "172/172 [==============================] - 38s 222ms/step - loss: 0.9612\n",
      "Epoch 73/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 0.9628\n",
      "Epoch 74/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 0.9628\n",
      "Epoch 75/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9675\n",
      "Epoch 76/100\n",
      "172/172 [==============================] - 38s 218ms/step - loss: 0.9680\n",
      "Epoch 77/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9712\n",
      "Epoch 78/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9720\n",
      "Epoch 79/100\n",
      "172/172 [==============================] - 37s 216ms/step - loss: 0.9731\n",
      "Epoch 80/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9802\n",
      "Epoch 81/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9811\n",
      "Epoch 82/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9837\n",
      "Epoch 83/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9852\n",
      "Epoch 84/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9904\n",
      "Epoch 85/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 0.9978\n",
      "Epoch 86/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 0.9925\n",
      "Epoch 87/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 0.9951\n",
      "Epoch 88/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0041\n",
      "Epoch 89/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0112\n",
      "Epoch 90/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 1.0157\n",
      "Epoch 91/100\n",
      "172/172 [==============================] - 38s 219ms/step - loss: 1.0133\n",
      "Epoch 92/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.0185\n",
      "Epoch 93/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.0253\n",
      "Epoch 94/100\n",
      "172/172 [==============================] - 38s 220ms/step - loss: 1.0281\n",
      "Epoch 95/100\n",
      "172/172 [==============================] - 37s 217ms/step - loss: 1.0293\n",
      "Epoch 96/100\n",
      "172/172 [==============================] - 38s 218ms/step - loss: 1.0415\n",
      "Epoch 97/100\n",
      "172/172 [==============================] - 38s 218ms/step - loss: 1.0373\n",
      "Epoch 98/100\n",
      "172/172 [==============================] - 37s 218ms/step - loss: 1.0442\n",
      "Epoch 99/100\n",
      "172/172 [==============================] - 38s 218ms/step - loss: 1.0509\n",
      "Epoch 100/100\n",
      "172/172 [==============================] - 38s 222ms/step - loss: 1.0540\n"
     ]
    }
   ],
   "source": [
    "# Directory and filename pattern for the per-epoch weight checkpoints.\n",
    "output_dir = 'text_generation_checkpoint'\n",
    "# exist_ok=True creates the directory only when it is missing, with no separate existence check.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')\n",
    "\n",
    "# Training configuration.\n",
    "checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
    "    filepath = checkpoint_prefix, save_weights_only = True)\n",
    "epochs = 100\n",
    "callbacks = [\n",
    "    checkpoint_callback,\n",
    "    # fit() receives no validation data, so the default 'val_loss' metric never exists;\n",
    "    # monitor the training loss instead so early stopping can actually trigger.\n",
    "    keras.callbacks.EarlyStopping(monitor='loss', patience=5, min_delta=1e-3)\n",
    "]\n",
    "\n",
    "# Train with the full callback list (checkpointing and early stopping).\n",
    "history = model.fit(seq_dataset, epochs = epochs, callbacks = callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'text_generation_checkpoint\\\\ckpt_100'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path prefix of the most recent checkpoint found in output_dir.\n",
    "tf.train.latest_checkpoint(output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 载入模型权重，进行文本生成\n",
    "文本生成的流程：\n",
    "- start ch sequence A\n",
    "- A -> model -> b\n",
    "- A.append(b) -> B\n",
    "- B(Ab) -> model -> c\n",
    "- B.append(c) -> C\n",
    "- C(Abc) -> model -> ...一直循环结束"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (1, None, 256)            16640     \n",
      "_________________________________________________________________\n",
      "simple_rnn_1 (SimpleRNN)     (1, None, 1024)           1311744   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (1, None, 65)             66625     \n",
      "=================================================================\n",
      "Total params: 1,395,009\n",
      "Trainable params: 1,395,009\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the same architecture with a batch size of 1 for generation.\n",
    "model2 = build_model(\n",
    "    vocab_size=vocab_size,\n",
    "    embedding_dim=embedding_dim,\n",
    "    rnn_units=rnn_units,\n",
    "    batch_size=1)\n",
    "# Load the weights saved by the most recent training checkpoint.\n",
    "model2.load_weights(tf.train.latest_checkpoint(output_dir))\n",
    "# Fix the input shape: one sequence of arbitrary length.\n",
    "model2.build(tf.TensorShape([1, None]))\n",
    "\n",
    "model2.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generates 1000 characters of text by default.\n",
    "def generate_text(model, start_string, num_generate = 1000):\n",
    "    \"\"\"Sample `num_generate` characters from `model`, seeded with `start_string`.\n",
    "\n",
    "    Every character of `start_string` is looked up in `char2idx`, so a character\n",
    "    missing from that mapping raises KeyError.\n",
    "\n",
    "    Returns:\n",
    "        `start_string` followed by the sampled characters.\n",
    "    \"\"\"\n",
    "    # Encode the prompt as ids and add a leading batch axis: [len(start_string)] -> [1, len(start_string)].\n",
    "    prompt_ids = [char2idx[ch] for ch in start_string]\n",
    "    input_eval = tf.expand_dims(prompt_ids, 0)\n",
    "\n",
    "    model.reset_states()\n",
    "    generated_chars = []\n",
    "\n",
    "    # Each iteration: run the model, sample one character, feed that character back in.\n",
    "    # NOTE(review): only the newest id is fed back, so earlier context reaches the model\n",
    "    # only through recurrent state kept between calls - confirm the model is stateful.\n",
    "    for _ in range(num_generate):\n",
    "        # The model returns [batch_size, input_len, vocab_size]; drop the batch axis.\n",
    "        step_logits = tf.squeeze(model(input_eval), 0)\n",
    "        # Sample one id per input position ([input_len, 1]) and keep the draw for the\n",
    "        # last position, which is the prediction for the next character.\n",
    "        sampled_ids = tf.random.categorical(step_logits, num_samples = 1)\n",
    "        predicted_id = sampled_ids[-1, 0].numpy()\n",
    "        generated_chars.append(idx2char[predicted_id])\n",
    "        # The sampled id, with a batch axis added, becomes the next input.\n",
    "        input_eval = tf.expand_dims([predicted_id], 0)\n",
    "\n",
    "    return start_string + ''.join(generated_chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All:\n",
      "\n",
      "E: at.\n",
      "Thold ondiny thesld:\n",
      "'dscoorar tot,'\n",
      "\n",
      "TRK:\n",
      "TINUDo; metthy g K:\n",
      "ERLIORAN ig Her'shay aurou hing be wisu tho k!\n",
      "\n",
      "\n",
      "\n",
      "Whave shw k that me w alinghackitowhinamys halll it, s lerositusterther prtionthr,\n",
      "A thieade mannthepe sou beng,'llw ame ws windok.\n",
      "\n",
      "T:\n",
      "M:\n",
      "MNTI' wif meng d d bed.\n",
      "Hered My hal ghen io avind goul lld.\n",
      "Cithed ORE:\n",
      "\n",
      "USAMesst, w w bure IE:\n",
      "I boulyorirouconoume in, whis fy vessthelibes siou f ar.\n",
      "\n",
      "OMy cucl s Wes t n:\n",
      "Thesherel anson,\n",
      "CLI:\n",
      "CHeraige?\n",
      "ICibe f INondg, y, veis.\n",
      "Ofuthadsthan moven:\n",
      "ELecin,\n",
      "TERUK:\n",
      "He RUCow akirde.\n",
      "Br\n",
      "ANULee se\n",
      "AMI athal gou s cosu the\n",
      "\n",
      "Cikir divead t helio t, anthotie ceny br?\n",
      "Pldomy w l f O:\n",
      "A:\n",
      "fe llf hiratha'dse wind bath beng the\n",
      "\n",
      "\n",
      "Sheave'ded wno Whe ld s, bes cerd ire wie d ate; e couthame fonsth;\n",
      "SAUERWhestheny lend J\n",
      "Thangredas marsthienso y an Wing, atoulertss llighigrdloux:\n",
      "wind blo Inits ht. if, g, pr TIOLer-whinou INNThe ETERCAt, I brs thng omy artely A: med vee hathenire:\n",
      "\n",
      "Anonee ig f IICLI\n",
      "ALO: bll she omowand, p thou,\n",
      "\n",
      "PE:\n",
      "Ny LAngi\n"
     ]
    }
   ],
   "source": [
    "# Generate text seeded with the prompt 'All:' and show it.\n",
    "new_text = generate_text(model2, start_string='All:')\n",
    "print(new_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
